#!/usr/bin/env python3
# coding=utf8
"""
This script informally (not rigorously) checks that talib.SMA is equivalent to
talib.MA when the latter is called with talib.MA_Type.SMA, e.g.:
talib.SMA(numpy.array([i * 0.1 for i in range(1, 9)]), 2)
talib.MA( numpy.array([i * 0.1 for i in range(1, 9)]), 2, talib.MA_Type.SMA)
"""
import json
import numpy
import random
import talib
from typing import Any, Dict, List, Set, Tuple, Type, Optional, Union, Callable


def gen_price_list(base: float, count: int) -> List[float]:
    """
    Generate a random-walk price series of ``count`` values starting at ``base``.

    The first value is ``float(base)`` (not rounded here); every later value is
    drawn uniformly from within +/-10% of the previous value and rounded to
    2 decimal places.

    Args:
        base: Starting price, used as the first element of the series.
        count: Number of prices to generate.

    Returns:
        A list holding exactly ``count`` prices; an empty list when ``count``
        is zero or negative.
    """
    if count <= 0:
        # Without this guard the seed value would be returned even though no
        # prices were requested, so the result would not have ``count`` items.
        return []
    values: List[float] = [float(base)]
    for _ in range(count - 1):
        dn: float = values[-1] * (1 - 0.1)
        up: float = values[-1] * (1 + 0.1)
        values.append(round(random.uniform(dn, up), 2))
    return values


def check_once(base: float, count: int, timeperiod: int) -> bool:
    """
    Check, on one random price series, that talib.SMA matches talib.MA in SMA mode.

    Two properties are compared element-wise (exact equality):
      * ``talib.SMA(prices, timeperiod)`` equals
        ``talib.MA(prices, timeperiod, talib.MA_Type.SMA)``;
      * ``talib.MA(prices, 1, talib.MA_Type.SMA)`` equals ``prices`` itself.

    Args:
        base: Starting price; rounded to 2 decimals before use.
        count: Number of prices to generate.
        timeperiod: Window length passed to talib.SMA / talib.MA.

    Returns:
        True when both properties hold.

    Raises:
        RuntimeError: If either comparison fails; the message is a JSON dump
            of the inputs and generated prices so the case can be reproduced.
    """
    base = round(base, 2)
    values: List[float] = gen_price_list(base=base, count=count)
    prices: numpy.ndarray = numpy.array(object=values)
    sma_ret: numpy.ndarray = talib.SMA(prices, timeperiod)
    ma_ret: numpy.ndarray = talib.MA(prices, timeperiod, talib.MA_Type.SMA)
    ma_period1: numpy.ndarray = talib.MA(prices, 1, talib.MA_Type.SMA)
    # ndarray.all() only reports whether every element is truthy, so comparing
    # a.all() with b.all() says nothing about a == b. Compare the arrays
    # element-wise instead. equal_nan=True is needed because the moving-average
    # outputs can contain NaN (TA-Lib pads the warm-up positions with NaN) and
    # NaN != NaN under ordinary comparison.
    same_sma: bool = numpy.array_equal(ma_ret, sma_ret, equal_nan=True)
    same_prices: bool = numpy.array_equal(ma_period1, prices, equal_nan=True)
    if not (same_sma and same_prices):
        stat: dict = {"base": base, "count": count, "timeperiod": timeperiod, "values": values, }
        raise RuntimeError(json.dumps(obj=stat))
    return True


def check(total: int):
    """
    Run ``check_once`` on ``total`` randomly drawn parameter sets.

    Each round draws a base price in [1.0, 1000.0] (rounded to 2 decimals),
    a series length in [1, 10000] and a time period in [2, max(count, 2)].
    A progress line is printed after every 1000 completed rounds.
    """
    for counter in range(1, total + 1):
        base: float = round(random.uniform(1.0, 1000.0), 2)
        count: int = random.randint(1, 10000)
        # talib.SMA requires timeperiod >= 2, so never let the upper bound drop below 2.
        upper: int = max(count, 2)
        timeperiod: int = random.randint(2, upper)
        check_once(base=base, count=count, timeperiod=timeperiod)
        if counter % 1000 != 0:
            continue
        print(f"total={total}, counter={counter}, {round(counter/total*100,2)}%,")


if __name__ == "__main__":
    # Script entry point: run ten million randomized comparison rounds.
    check(total=10_000_000)
